
import csv
from torch.utils.data import Dataset, DataLoader

class NameCountry(Dataset):
    """Dataset of ``(name, country index)`` pairs read from a two-column CSV.

    Every non-empty CSV row must hold exactly two fields: a name followed by
    its country. A country's integer index is its position in the sorted list
    of distinct country strings found in the loaded file.

    Attributes set by ``__init__``:
        names_list: names, in file order.
        countrys: country string of each row, parallel to ``names_list``.
        country_list: sorted distinct country strings.
        country_dict: mapping from country string to its index in
            ``country_list``.
    """

    def __init__(self, is_train=True, data_dir='./../data/name2country'):
        """Load the train or test split from ``data_dir``.

        Args:
            is_train: Read ``names_train.csv`` when true, otherwise
                ``names_test.csv``.
            data_dir: Directory containing the two CSV files. The default is
                a relative path, so it is resolved against the current
                working directory.

        Raises:
            OSError: If the CSV file cannot be opened.
            ValueError: If a non-empty row does not have exactly two fields.
        """
        super().__init__()
        file_name = 'names_train.csv' if is_train else 'names_test.csv'
        file_path = f'{data_dir}/{file_name}'
        print(file_path)
        # newline='' lets the csv module do its own line-ending handling.
        with open(file_path, 'r', encoding='utf8', newline='') as f:
            # csv.reader yields [] for blank lines (e.g. a trailing empty
            # line); drop them so the two-field unpacking below cannot fail
            # on them.
            rows = [row for row in csv.reader(f) if row]

        self.names_list = [name for name, _ in rows]
        self.countrys = [country for _, country in rows]
        self.country_list = sorted(set(self.countrys))
        self.country_dict = self.get_CountryDict()

    def get_CountryDict(self):
        """Return a dict mapping each country in ``country_list`` to its index."""
        return {country: i for i, country in enumerate(self.country_list)}

    def __getitem__(self, index):
        """Return ``(name, country_index)`` for the sample at ``index``."""
        return self.names_list[index], self.country_dict[self.countrys[index]]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.names_list)


# Smoke check: load the training split and print every name that was read.
names = NameCountry(is_train=True)
print(names.names_list)